import pandas as pd
import numpy as np
import json
from sklearn.metrics import precision_recall_fscore_support, roc_curve, auc
from sklearn.metrics import hamming_loss, accuracy_score, multilabel_confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time
import os
from pathlib import Path
import traceback

# Global matplotlib font sizes, applied to every figure this script saves.
plt.rcParams.update({
    'font.size': 14,  # base font size
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

def get_probabilities_for_label(predictions, label):
    """Return the score a raw prediction cell assigns to *label*, as a float.

    Accepted forms of *predictions*:
      * a dict, or a string holding a JSON / single-quoted dict, mapping
        label -> score: the score for *label* (0.0 if the key is absent);
      * a pipe-separated string ("A | B"): 1.0 if *label* is one of the parts;
      * any other string: 1.0 if it equals *label* after stripping.

    Everything else yields 0.0. This includes NaN cells, JSON that is not an
    object, and scores that are non-numeric, NaN or infinite. Unlike the
    previous bare ``except:``, interpreter-level exceptions such as
    KeyboardInterrupt are no longer swallowed.
    """
    import math

    if isinstance(predictions, str):
        try:
            # Model output often uses Python-style single quotes; normalise them
            # so json can read it (note: this mangles labels containing apostrophes).
            parsed = json.loads(predictions.replace("'", '"'))
        except json.JSONDecodeError:
            if "|" in predictions:
                listed = [part.strip() for part in predictions.split("|")]
                return 1.0 if label in listed else 0.0
            return 1.0 if predictions.strip() == label else 0.0
        if not isinstance(parsed, dict):
            return 0.0
        predictions = parsed

    if not isinstance(predictions, dict):
        return 0.0

    try:
        score = float(predictions.get(label, 0))
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # roc_curve rejects NaN/inf scores, so treat them as "no evidence".
    return score if math.isfinite(score) else 0.0

def get_model_colors():
    """Return a fresh mapping from model display name to its hex plot colour."""
    standard_green = '#2ca02c'
    alt_red = '#d62728'
    return dict([
        ('ConflLlama-Standard', standard_green),
        ('ConflLlama-Alt', alt_red),
    ])

def print_progress(message, indent=0):
    """Print *message* to stdout.

    The line is prefixed with a local ``[HH:MM:SS]`` timestamp and indented
    two spaces per *indent* level.
    """
    stamp = time.strftime("%H:%M:%S", time.localtime())
    padding = "  " * indent
    print("[{}] {}{}".format(stamp, padding, message))

def sanitize_filename(name):
    """Return *name* with each of ``:``, ``.``, ``/``, ``\\`` and space replaced by ``_``."""
    underscore_table = str.maketrans(dict.fromkeys(':./\\ ', '_'))
    return name.translate(underscore_table)

def _coerce_score(value):
    """Return *value* as a non-negative finite float, or None if it is not a usable score.

    None, booleans, non-numeric strings, negatives, NaN and infinities are rejected.
    """
    import math

    if value is None or isinstance(value, bool):
        return None
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    if not math.isfinite(score) or score < 0:
        return None
    return score


def _ranked_labels(scores, threshold):
    """Rank a ``{label: score}`` mapping.

    Returns ``(top_label, labels_scoring_at_least_threshold)``, or None when no
    value in the mapping is a usable score.
    """
    valid = []
    for label, value in scores.items():
        score = _coerce_score(value)
        if score is not None:
            valid.append((label, score))
    if not valid:
        return None
    # list.sort is stable (also with reverse=True), so ties keep the key order.
    valid.sort(key=lambda item: item[1], reverse=True)
    return valid[0][0], [label for label, score in valid if score >= threshold]


def parse_prediction(pred_str, threshold=0.5):
    """Parse one model prediction into ``(primary_label, label_list)``.

    Supported forms of *pred_str*:
      * a dict, or a string holding a JSON / single-quoted dict, mapping labels
        to scores. The primary label is the highest-scoring one, and the list
        holds every label scoring >= *threshold* (it may be empty). Scores may
        be numbers or numeric strings; unusable scores are ignored.
      * a list/tuple of labels, e.g. the second element this function returns.
        Its string entries are stripped and empty ones dropped; the first
        remaining entry is the primary label.
      * a pipe-separated string "A | B": labels in the order given.
      * any other non-empty string: treated as a single label.

    "Unknown" is kept as an ordinary label. Unparseable input yields
    ``("Unknown", ["Unknown"])``; this covers NaN cells, blank strings, JSON
    that is not an object, and dicts with no usable score.
    """
    fallback = ("Unknown", ["Unknown"])
    try:
        if isinstance(pred_str, dict):
            return _ranked_labels(pred_str, threshold) or fallback

        if isinstance(pred_str, (list, tuple)):
            labels = [label.strip() for label in pred_str
                      if isinstance(label, str) and label.strip()]
            return (labels[0], labels) if labels else fallback

        if not isinstance(pred_str, str):
            return fallback

        try:
            # Model output often uses single quotes; normalise them for json.
            parsed = json.loads(pred_str.replace("'", '"'))
        except json.JSONDecodeError:
            # Not JSON: try the pipe-separated format first
            if "|" in pred_str:
                labels = [label.strip() for label in pred_str.split("|")]
                valid_labels = [label for label in labels if label]  # Keep Unknown
                if valid_labels:
                    return valid_labels[0], valid_labels
            # Otherwise treat the whole string as a single label
            clean_label = pred_str.strip()
            if clean_label:  # Keep Unknown
                return clean_label, [clean_label]
            return fallback

        if isinstance(parsed, dict):
            return _ranked_labels(parsed, threshold) or fallback
        return fallback
    except Exception as e:
        # Best-effort parsing: log and fall back rather than abort the whole run.
        print_progress(f"Error parsing prediction: {str(e)}", indent=2)
        return fallback

def get_gold_standard_labels(row):
    """Return the non-missing attack-type labels of *row*.

    The columns are read in the order attacktype1_txt, attacktype2_txt,
    attacktype3_txt. "Unknown" counts as a real label; only NaN/None is dropped.
    """
    gold_columns = ('attacktype1_txt', 'attacktype2_txt', 'attacktype3_txt')
    return [row[column] for column in gold_columns if pd.notna(row[column])]

def calculate_single_label_metrics(y_true, y_pred, labels):
    """Compute per-label precision/recall/F1/support plus macro and weighted averages.

    Args:
        y_true: gold labels, one per sample.
        y_pred: predicted labels, aligned with ``y_true``.
        labels: the labels to report and average over, in output order.

    Returns:
        dict mapping each label to ``{'Precision', 'Recall', 'F1-score', 'Support'}``.
        It also holds 'macro_avg' and 'weighted_avg', each mapping to
        ``{'Precision', 'Recall', 'F1-score', 'Accuracy'}``. Accuracy is
        computed over all samples, not only those whose label is in
        ``labels``. Weighted averages are 0.0 when the total support is zero,
        in line with ``zero_division=0``; previously numpy produced NaN there.
    """
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, zero_division=0
    )
    accuracy = accuracy_score(y_true, y_pred)

    metrics = {
        label: {
            'Precision': precision[i],
            'Recall': recall[i],
            'F1-score': f1[i],
            'Support': support[i],
        }
        for i, label in enumerate(labels)
    }

    total_support = support.sum()

    def weighted(values):
        # Support-weighted mean; guard the no-support case instead of 0/0 -> NaN.
        if total_support <= 0:
            return 0.0
        return float(np.dot(values, support) / total_support)

    metrics['macro_avg'] = {
        'Precision': np.mean(precision),
        'Recall': np.mean(recall),
        'F1-score': np.mean(f1),
        'Accuracy': accuracy,
    }

    metrics['weighted_avg'] = {
        'Precision': weighted(precision),
        'Recall': weighted(recall),
        'F1-score': weighted(f1),
        'Accuracy': accuracy,
    }

    return metrics

def _as_label_list(pred):
    """Return the list of predicted labels for one sample.

    Lists, tuples and sets are used directly. ``main`` passes in the label
    lists already produced by ``parse_prediction``. Any other value is treated
    as a raw prediction cell and parsed with ``parse_prediction``.
    """
    if isinstance(pred, (list, tuple, set)):
        return [label for label in pred if isinstance(label, str)]
    return parse_prediction(pred)[1]


def calculate_multi_label_metrics(true_labels_list, pred_labels_list, unique_labels):
    """Compute multi-label metrics comparing gold label sets with predicted label sets.

    Args:
        true_labels_list: per-sample iterables of gold labels (NaN entries ignored).
        pred_labels_list: per-sample predictions. Each is either an
            already-parsed label list or a raw prediction cell understood by
            ``parse_prediction``.
        unique_labels: ordered label vocabulary (a sequence). Labels outside it
            are ignored when the indicator matrices are built.

    Returns:
        dict of scalar metrics. It contains the hamming loss, subset
        (exact-match) accuracy, partial-match accuracy, true and predicted label
        cardinality, several 'Unknown' statistics, and label diversity. It also
        holds ``<label>_Precision``, ``<label>_Recall`` and ``<label>_F1`` for
        each label.

    Raises:
        ValueError: if the inputs are empty or of different lengths.
    """
    # Materialise both sides once, positionally (the gold side may be a Series
    # with a non-default index). Already-parsed prediction lists are used
    # as-is, not re-parsed: parse_prediction is meant for raw prediction cells.
    true_sets = list(true_labels_list)
    pred_sets = [_as_label_list(pred) for pred in pred_labels_list]

    n_samples = len(true_sets)
    if n_samples == 0 or n_samples != len(pred_sets):
        raise ValueError(
            f"Expected equally long, non-empty inputs; got {n_samples} gold "
            f"and {len(pred_sets)} predicted samples"
        )

    # O(1) label -> column lookup (first occurrence wins, like list.index)
    label_index = {}
    for col, label in enumerate(unique_labels):
        label_index.setdefault(label, col)

    # Binary indicator matrices: rows = samples, columns = unique_labels
    y_true = np.zeros((n_samples, len(unique_labels)), dtype=int)
    y_pred = np.zeros_like(y_true)
    for i, (true_set, pred_set) in enumerate(zip(true_sets, pred_sets)):
        for label in true_set:
            if pd.notna(label) and label in label_index:
                y_true[i, label_index[label]] = 1
        for label in pred_set:
            if label in label_index:
                y_pred[i, label_index[label]] = 1

    hamming = hamming_loss(y_true, y_pred)
    subset_accuracy = accuracy_score(y_true, y_pred)

    # Share of samples where prediction and gold overlap in at least one label
    partial_accuracy = (y_true & y_pred).any(axis=1).mean()

    # 'Unknown' statistics, computed from the same parsed predictions
    true_unknown = np.array(['Unknown' in labels for labels in true_sets], dtype=int)
    pred_unknown = np.array(['Unknown' in labels for labels in pred_sets], dtype=int)
    unknown_predictions = int(pred_unknown.sum())
    unknown_true = int(true_unknown.sum())
    unknown_rate = unknown_predictions / n_samples
    unknown_accuracy = accuracy_score(true_unknown, pred_unknown)

    label_cardinality_true = np.mean(np.sum(y_true, axis=1))
    label_cardinality_pred = np.mean(np.sum(y_pred, axis=1))

    # Exact-match ratio restricted to samples where neither side mentions Unknown
    both_known = (true_unknown == 0) & (pred_unknown == 0)
    rows_match = (y_true == y_pred).all(axis=1)
    total_non_unknown = int(both_known.sum())
    exact_match_ratio = rows_match[both_known].mean() if total_non_unknown > 0 else 0

    metrics = {
        'Hamming_Loss': hamming,
        'Subset_Accuracy': subset_accuracy,
        'Partial_Match_Accuracy': partial_accuracy,
        'Label_Cardinality_True': label_cardinality_true,
        'Label_Cardinality_Predicted': label_cardinality_pred,
        'Unknown_Rate': unknown_rate,
        'Unknown_Accuracy': unknown_accuracy,
        'Unknown_True_Rate': unknown_true / n_samples,
        'Unknown_Pred_Rate': unknown_predictions / n_samples,
        'Exact_Match_Ratio_Non_Unknown': exact_match_ratio,
        'Average_Labels_Per_Prediction': label_cardinality_pred,
        # Distinct predicted labels, including ones outside unique_labels
        'Label_Diversity': len({label for labels in pred_sets for label in labels}),
    }

    # Per-class precision/recall/F1 from one 2x2 confusion matrix per label
    confusion_matrices = multilabel_confusion_matrix(y_true, y_pred)
    for i, label in enumerate(unique_labels):
        tn, fp, fn, tp = confusion_matrices[i].ravel()
        metrics[f'{label}_Precision'] = tp / (tp + fp) if (tp + fp) > 0 else 0
        metrics[f'{label}_Recall'] = tp / (tp + fn) if (tp + fn) > 0 else 0
        metrics[f'{label}_F1'] = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0

    return metrics

def calculate_class_deltas(single_label_results, unique_labels, output_dir):
    """Compute per-class relative (%) changes of ConflLlama-Alt over ConflLlama-Standard.

    NOTE(review): this module defines ``calculate_class_deltas`` twice. The
    later definition (absolute deltas) rebinds the name when the module is
    executed, so ``main`` never reaches this percentage version. Keep only one.

    Args:
        single_label_results: model display name -> output of
            ``calculate_single_label_metrics``. It must contain both
            'ConflLlama-Standard' and 'ConflLlama-Alt'.
        unique_labels: labels (heatmap rows) to compare.
        output_dir: receives ``class_performance_deltas_percentage.png`` and
            ``class_delta_statistics.csv``.

    Returns:
        ``(deltas, delta_stats)``. ``deltas[label]`` maps Precision, Recall and
        F1-score to their percentage change; the change is NaN when the
        Standard value is 0, because a relative change from zero is undefined.
        It also maps Support to its absolute change. ``delta_stats`` summarises
        the F1 changes, the 'Unknown' class's changes (NaN if there is no such
        label) and the mean change per metric.

    Raises:
        ValueError: if either model's results are missing.
    """
    print_progress("Calculating class-wise performance deltas...")

    standard_model = 'ConflLlama-Standard'
    alt_model = 'ConflLlama-Alt'
    missing = [m for m in (standard_model, alt_model) if m not in single_label_results]
    if missing:
        raise ValueError(f"calculate_class_deltas needs results for: {', '.join(missing)}")

    percentage_metrics = ['Precision', 'Recall', 'F1-score']
    columns = percentage_metrics + ['Support']

    deltas = {}
    for label in unique_labels:
        standard = single_label_results[standard_model][label]
        alt = single_label_results[alt_model][label]
        row = {}
        for metric in percentage_metrics:
            base = standard[metric]
            # NaN (unlike the old inf) is skipped by mean/idxmax and left blank
            # in the heatmap instead of wrecking its colour scale.
            row[metric] = (alt[metric] - base) / base * 100 if base != 0 else float('nan')
        # Support stays an absolute difference
        row['Support'] = alt['Support'] - standard['Support']
        deltas[label] = row

    delta_df = pd.DataFrame.from_dict(deltas, orient='index', columns=columns)

    # Heatmap of the percentage changes
    plt.figure(figsize=(15, max(4, len(unique_labels) * 0.5)))
    sns.heatmap(delta_df[percentage_metrics],
                cmap='RdYlBu',
                center=0,
                annot=True,
                fmt='.1f',
                cbar_kws={'label': 'Percentage Change (%)'})

    plt.title('Class-wise Performance Changes\n(Alternative vs Standard Model)')
    plt.ylabel('Attack Type')
    plt.xlabel('Metric')
    plt.tight_layout()

    save_path = os.path.join(output_dir, 'class_performance_deltas_percentage.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    f1_change = delta_df['F1-score'].dropna()
    has_f1 = not f1_change.empty

    def unknown_change(metric):
        return delta_df.loc['Unknown', metric] if 'Unknown' in delta_df.index else float('nan')

    delta_stats = {
        'largest_improvement': {
            'class': f1_change.idxmax() if has_f1 else None,
            'value': f1_change.max() if has_f1 else float('nan'),
            'metric': 'F1-score'
        },
        'largest_decline': {
            'class': f1_change.idxmin() if has_f1 else None,
            'value': f1_change.min() if has_f1 else float('nan'),
            'metric': 'F1-score'
        },
        'unknown_delta': {
            'precision_change': unknown_change('Precision'),
            'recall_change': unknown_change('Recall'),
            'f1_change': unknown_change('F1-score')
        },
        'average_changes': {
            metric: delta_df[metric].mean() for metric in percentage_metrics
        }
    }

    # Flatten the nested statistics into columns (e.g. "largest_improvement.class")
    # so the CSV holds readable values rather than dict reprs.
    pd.json_normalize(delta_stats).to_csv(
        os.path.join(output_dir, 'class_delta_statistics.csv'), index=False
    )

    return deltas, delta_stats

def calculate_roc_curves(df, model_cols, true_col, unique_labels):
    """Compute one-vs-rest ROC curves per label and a macro-averaged curve per model.

    Args:
        df: DataFrame holding ``true_col`` and every column in ``model_cols``.
        model_cols: prediction columns. Each cell is scored with
            ``get_probabilities_for_label``.
        true_col: column with the single gold label per row.
        unique_labels: labels to build one-vs-rest curves for.

    Returns:
        ``(roc_data, class_roc_data)``. ``roc_data[model_col]`` holds the mean
        TPR interpolated onto 100 evenly spaced FPR points, together with the
        mean per-label AUC. ``class_roc_data[model_col][label]`` holds that
        label's raw ``fpr``, ``tpr`` and ``auc``. A model appears only if at
        least one of its labels produced a curve.
    """
    roc_data = {}
    class_roc_data = {}
    base_fpr = np.linspace(0, 1, 100)  # shared grid for averaging curves

    for model_col in model_cols:
        print_progress(f"Processing ROC curve for {model_col}...", indent=2)
        class_data = {}
        interpolated_tprs = []

        for label in unique_labels:
            y_true = (df[true_col] == label).astype(int)

            # ROC is undefined with a single class: sklearn only warns and returns
            # a NaN TPR, which would silently make the model's mean AUC NaN.
            n_positive = int(y_true.sum())
            if n_positive == 0 or n_positive == len(y_true):
                print_progress(f"Skipping ROC for {label}: only one class present", indent=3)
                continue

            y_score = df[model_col].apply(get_probabilities_for_label, args=(label,))

            try:
                fpr, tpr, _ = roc_curve(y_true, y_score)
                roc_auc = auc(fpr, tpr)
            except Exception as e:
                print_progress(f"Error calculating ROC for {label}: {str(e)}", indent=3)
                continue

            class_data[label] = {
                'fpr': fpr,
                'tpr': tpr,
                'auc': roc_auc
            }
            interpolated_tprs.append(np.interp(base_fpr, fpr, tpr))

        if class_data:
            roc_data[model_col] = {
                'fpr': base_fpr,
                'tpr': np.mean(interpolated_tprs, axis=0),
                'auc': np.mean([data['auc'] for data in class_data.values()])
            }
            class_roc_data[model_col] = class_data

    return roc_data, class_roc_data

def plot_roc_curves(roc_data, model_name_mapping, colors, output_dir):
    """Overlay every model's averaged ROC curve on one figure.

    The figure is saved as ``<output_dir>/overall_roc_curves.png``. Legend
    names come from ``model_name_mapping`` after stripping the
    ``_predictions`` suffix; colours come from ``colors``, with 'gray' when a
    name has no colour.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Chance-level reference diagonal
    ax.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.8, label='Random')

    for column, curve in roc_data.items():
        short_name = column.replace('_predictions', '')
        shown_as = model_name_mapping.get(short_name, short_name)
        ax.plot(curve['fpr'], curve['tpr'],
                label=f"{shown_as} (AUC = {curve['auc']:.3f})",
                color=colors.get(shown_as, 'gray'),
                linewidth=2)

    # Final grid style (a dashed grid at alpha 0.7)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlabel('False Positive Rate', fontsize=14)
    ax.set_ylabel('True Positive Rate', fontsize=14)
    ax.set_title('Overall ROC Curves - All Models', pad=20, fontsize=16)
    ax.legend(loc='lower right', fontsize=12)
    ax.tick_params(axis='both', which='major', labelsize=12)
    fig.tight_layout()

    fig.savefig(os.path.join(output_dir, 'overall_roc_curves.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_class_roc_curves(class_roc_data, model_name_mapping, colors, unique_labels, output_dir):
    """Save one ROC figure per label, overlaying every model that has a curve for it.

    Figures are written to
    ``<output_dir>/class_roc_curves/roc_curves_<sanitized label>.png``, and the
    directory is created if needed.
    """
    print_progress("Creating class-wise ROC curve plots...", indent=1)

    target_dir = os.path.join(output_dir, 'class_roc_curves')
    os.makedirs(target_dir, exist_ok=True)

    for label in unique_labels:
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.8, label='Random')

        for column, per_label in class_roc_data.items():
            if label not in per_label:
                continue  # this model produced no curve for the label
            short_name = column.replace('_predictions', '')
            shown_as = model_name_mapping.get(short_name, short_name)
            curve = per_label[label]
            ax.plot(curve['fpr'], curve['tpr'],
                    label=f"{shown_as} (AUC = {curve['auc']:.3f})",
                    color=colors.get(shown_as, 'gray'),
                    linewidth=2)

        ax.grid(True, linestyle='--', alpha=0.7)
        ax.set_xlabel('False Positive Rate', fontsize=14)
        ax.set_ylabel('True Positive Rate', fontsize=14)
        ax.set_title(f'ROC Curves for {label}', pad=20, fontsize=16)
        ax.legend(loc='lower right', fontsize=12)
        ax.tick_params(axis='both', which='major', labelsize=12)
        fig.tight_layout()

        file_stem = f'roc_curves_{sanitize_filename(label)}.png'
        fig.savefig(os.path.join(target_dir, file_stem), dpi=300, bbox_inches='tight')
        plt.close(fig)

def plot_confusion_matrix(y_true, y_pred, labels, title, filename):
    """Save a side-by-side confusion-matrix figure to *filename*.

    The left panel shows rates normalised per true label (row-wise) and the
    right panel shows raw counts. Both are laid out on ``labels`` in the given
    order: labels absent from the data appear as zero rows/columns, and
    predicted values outside ``labels`` are not shown.
    """
    # reindex(fill_value=0) both inserts missing labels as zeros and fixes the order
    counts = pd.crosstab(y_true, y_pred).reindex(index=labels, columns=labels, fill_value=0)
    rates = pd.crosstab(y_true, y_pred, normalize='index').reindex(
        index=labels, columns=labels, fill_value=0)

    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(24, 10))
    panels = (
        (left_ax, rates, '.3f', 'Normalized Frequency', 'Normalized by True Labels'),
        (right_ax, counts, 'd', 'Count', 'Raw Counts'),
    )

    for ax, matrix, number_format, bar_label, subtitle in panels:
        sns.heatmap(matrix, annot=True, fmt=number_format, cmap='YlOrRd', square=True,
                    cbar_kws={'label': bar_label},
                    annot_kws={'size': 10}, ax=ax)
        # Colour-bar label size can only be adjusted once the heatmap exists
        ax.collections[0].colorbar.ax.set_ylabel(bar_label, size=12)

        ax.set_title(f'{title}\n({subtitle})', pad=20, size=16)
        ax.set_ylabel('True Label', size=14)
        ax.set_xlabel('Predicted Label', size=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', size=12)
        plt.setp(ax.get_yticklabels(), rotation=0, size=12)

    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)

def analyze_error_patterns(true_labels, pred_labels, unique_labels, output_dir):
    """Count (true, predicted) confusions, save them as CSV and plot the most frequent.

    Args:
        true_labels: iterable of gold labels. Missing values (NaN/None) are
            skipped, since there is no gold label to be wrong against.
        pred_labels: iterable of predicted labels, aligned with ``true_labels``.
        unique_labels: accepted for interface compatibility; not used.
        output_dir: receives ``error_patterns.csv`` and ``top_error_patterns.png``;
            created if it does not exist.

    Returns:
        DataFrame with columns 'True Label', 'Predicted Label', 'Count',
        'Avg Severity', sorted by Count descending (ties keep first-seen order).
    """
    from collections import Counter

    os.makedirs(output_dir, exist_ok=True)

    error_counts = Counter(
        (true, pred)
        for true, pred in zip(true_labels, pred_labels)
        if pd.notna(true) and true != pred
    )

    # Severity is a flat placeholder (1.0 for every error); the column is kept
    # so the CSV layout does not change.
    error_df = pd.DataFrame(
        [(true, pred, count, 1.0) for (true, pred), count in error_counts.items()],
        columns=['True Label', 'Predicted Label', 'Count', 'Avg Severity']
    ).sort_values('Count', ascending=False, kind='stable')

    error_df.to_csv(os.path.join(output_dir, 'error_patterns.csv'), index=False)

    top_n = 20
    top_errors = error_df.head(top_n)
    # str() so non-string labels cannot break the truncated tick text
    error_labels = [f"{str(true)[:10]}→{str(pred)[:10]}" for true, pred in
                    zip(top_errors['True Label'], top_errors['Predicted Label'])]

    plt.figure(figsize=(15, 8))
    plt.bar(range(len(top_errors)), top_errors['Count'])
    plt.xticks(range(len(top_errors)), error_labels, rotation=45, ha='right', fontsize=12)
    plt.title('Top Error Patterns', pad=20, fontsize=16)
    plt.xlabel('Error Type', fontsize=14)
    plt.ylabel('Count', fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, 'top_error_patterns.png'), dpi=300, bbox_inches='tight')
    plt.close()

    return error_df

def create_comparative_tables(single_label_results, multi_label_results, output_dir):
    """Write per-class and overall comparison tables for the evaluated models.

    Outputs written to ``output_dir``, which is expected to exist:
      * ``class_performance_comparison.csv``: one row per label, with every
        model's Precision/Recall/F1-score/Support. When at least two models are
        present it also has ``<metric>_Diff_Pct``, the relative change (%) of
        the second model over the first. The change is 0 when both values are
        0, and NaN when only the baseline is 0 (previously ±inf).
      * ``model_comprehensive_comparison.csv`` and ``.html``: one row per
        model, with macro/weighted averages and all multi-label metrics.

    Args:
        single_label_results: model name -> ``calculate_single_label_metrics`` output.
        multi_label_results: model name -> ``calculate_multi_label_metrics``
            output; must contain every model in ``single_label_results``.
        output_dir: destination directory.

    Returns:
        ``(df_summary, df_class)``.

    Raises:
        ValueError: if ``single_label_results`` is empty.
    """
    models = list(single_label_results.keys())
    if not models:
        raise ValueError("create_comparative_tables needs results for at least one model")

    # Per-class performance comparison (labels taken from the first model)
    rows = []
    for label in single_label_results[models[0]]:
        if label in ('macro_avg', 'weighted_avg'):
            continue
        row = {'Label': label}
        for model in models:
            for metric in ['Precision', 'Recall', 'F1-score', 'Support']:
                row[f'{model}_{metric}'] = single_label_results[model][label][metric]
        rows.append(row)

    df_class = pd.DataFrame(rows)

    # Relative differences need a second model to compare against
    if len(models) >= 2 and not df_class.empty:
        base_model, compare_model = models[0], models[1]
        for metric in ['Precision', 'Recall', 'F1-score']:
            base = df_class[f'{base_model}_{metric}']
            other = df_class[f'{compare_model}_{metric}']
            pct = (other - base) / base.where(base != 0) * 100
            # Zero baseline: "no change" when both are 0, otherwise undefined (NaN)
            df_class[f'{metric}_Diff_Pct'] = pct.mask((base == 0) & (other == 0), 0.0)

    df_class.to_csv(os.path.join(output_dir, 'class_performance_comparison.csv'), index=False)

    # Summary comparison table: one row per model
    summary_data = []
    for model in models:
        row = {'Model': model}
        for metric in ['Precision', 'Recall', 'F1-score', 'Accuracy']:
            row[f'Macro_{metric}'] = single_label_results[model]['macro_avg'][metric]
            row[f'Weighted_{metric}'] = single_label_results[model]['weighted_avg'][metric]
        row.update(multi_label_results[model])
        summary_data.append(row)

    df_summary = pd.DataFrame(summary_data)
    df_summary.to_csv(os.path.join(output_dir, 'model_comprehensive_comparison.csv'), index=False)

    def style_table(val):
        # Negative floats in red, everything else in green
        color = 'red' if isinstance(val, float) and val < 0 else 'green'
        return f'color: {color}'

    numeric_cols = df_summary.columns[1:]
    styler = df_summary.style \
        .format(precision=3) \
        .background_gradient(cmap='YlOrRd', subset=numeric_cols)
    # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map
    elementwise = styler.map if hasattr(styler, 'map') else styler.applymap
    styled_summary = elementwise(style_table, subset=numeric_cols).to_html()

    with open(os.path.join(output_dir, 'model_comprehensive_comparison.html'), 'w',
              encoding='utf-8') as f:
        f.write(styled_summary)

    return df_summary, df_class
    
def analyze_error_patterns(true_labels, pred_labels, unique_labels, output_dir):
    """Count (true, predicted) confusions, save them as CSV and plot the most frequent.

    Args:
        true_labels: iterable of gold labels. Missing values (NaN/None) are
            skipped, since there is no gold label to be wrong against; a
            missing label previously crashed the tick-label slicing.
        pred_labels: iterable of predicted labels, aligned with ``true_labels``.
        unique_labels: accepted for interface compatibility; not used.
        output_dir: receives ``error_patterns.csv`` and ``top_error_patterns.png``;
            created if it does not exist.

    Returns:
        DataFrame with columns 'True Label', 'Predicted Label', 'Count',
        'Avg Severity', sorted by Count descending (ties keep first-seen order).
    """
    from collections import Counter

    os.makedirs(output_dir, exist_ok=True)

    error_counts = Counter(
        (true, pred)
        for true, pred in zip(true_labels, pred_labels)
        if pd.notna(true) and true != pred
    )

    # Severity is a flat placeholder (1.0 for every error); the column is kept
    # so the CSV layout does not change.
    error_df = pd.DataFrame(
        [(true, pred, count, 1.0) for (true, pred), count in error_counts.items()],
        columns=['True Label', 'Predicted Label', 'Count', 'Avg Severity']
    ).sort_values('Count', ascending=False, kind='stable')

    error_df.to_csv(os.path.join(output_dir, 'error_patterns.csv'), index=False)

    # Visualize the top N error patterns
    top_n = 20
    top_errors = error_df.head(top_n)
    # str() so non-string labels cannot break the truncated tick text
    error_labels = [f"{str(true)[:10]}→{str(pred)[:10]}" for true, pred in
                    zip(top_errors['True Label'], top_errors['Predicted Label'])]

    plt.figure(figsize=(15, 8))
    plt.bar(range(len(top_errors)), top_errors['Count'])
    plt.xticks(range(len(top_errors)), error_labels, rotation=45, ha='right', fontsize=12)
    plt.title('Top Error Patterns', pad=20, fontsize=16)
    plt.xlabel('Error Type', fontsize=14)
    plt.ylabel('Count', fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, 'top_error_patterns.png'), dpi=300, bbox_inches='tight')
    plt.close()

    return error_df
    
def calculate_class_deltas(single_label_results, unique_labels, output_dir):
    """Compute per-class absolute metric differences, ConflLlama-Alt minus ConflLlama-Standard.

    Draws a heatmap of the Precision/Recall/F1-score deltas and saves it to
    ``<output_dir>/class_performance_deltas.png``.

    Args:
        single_label_results: model display name -> output of
            ``calculate_single_label_metrics``. It must contain both
            'ConflLlama-Standard' and 'ConflLlama-Alt'.
        unique_labels: labels (heatmap rows) to compare; must be non-empty.
        output_dir: directory for the heatmap.

    Returns:
        ``(deltas, delta_stats)``. ``deltas[label][metric]`` is Alt minus
        Standard on the metric's own 0-1 scale, not a percentage.
        ``delta_stats`` holds:
          * the largest F1 improvement and decline (class and value);
          * the 'Unknown' class's deltas (0 if there is no such label);
          * the mean delta per metric.

    Raises:
        ValueError: if either model's results are missing (previously a bare
            KeyError) or ``unique_labels`` is empty.
    """
    print_progress("Calculating class-wise performance deltas...")

    metrics = ['Precision', 'Recall', 'F1-score']
    standard_model = 'ConflLlama-Standard'
    alt_model = 'ConflLlama-Alt'

    missing = [m for m in (standard_model, alt_model) if m not in single_label_results]
    if missing:
        raise ValueError(f"calculate_class_deltas needs results for: {', '.join(missing)}")
    if not unique_labels:
        raise ValueError("calculate_class_deltas needs at least one label")

    deltas = {
        label: {
            metric: single_label_results[alt_model][label][metric]
            - single_label_results[standard_model][label][metric]
            for metric in metrics
        }
        for label in unique_labels
    }

    delta_df = pd.DataFrame(deltas).T

    # 0.5 inch per label, but never a degenerate figure for a handful of labels
    plt.figure(figsize=(12, max(4, len(unique_labels) * 0.5)))

    heatmap = sns.heatmap(delta_df,
                          cmap='RdYlBu',
                          center=0,
                          annot=True,
                          fmt='.3f',
                          cbar_kws={'label': 'Performance Delta (Alt - Standard)'})

    # Colour-bar label size can only be adjusted once the heatmap exists
    heatmap.collections[0].colorbar.ax.set_ylabel('Performance Delta (Alt - Standard)', size=12)

    plt.title('Class-wise Performance Deltas\n(Alternative - Standard Model)',
              pad=20, fontsize=16)
    plt.ylabel('Attack Type', fontsize=14)
    plt.xlabel('Metric', fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout()

    save_path = os.path.join(output_dir, 'class_performance_deltas.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    has_unknown = 'Unknown' in delta_df.index

    delta_stats = {
        'largest_improvement': {
            'class': delta_df['F1-score'].idxmax(),
            'value': delta_df['F1-score'].max()
        },
        'largest_decline': {
            'class': delta_df['F1-score'].idxmin(),
            'value': delta_df['F1-score'].min()
        },
        'unknown_delta': {
            'precision_change': delta_df.loc['Unknown', 'Precision'] if has_unknown else 0,
            'recall_change': delta_df.loc['Unknown', 'Recall'] if has_unknown else 0,
            'f1_change': delta_df.loc['Unknown', 'F1-score'] if has_unknown else 0
        },
        'average_changes': {metric: delta_df[metric].mean() for metric in metrics}
    }

    return deltas, delta_stats

def main():
    """Evaluate ConflLlama-Standard against ConflLlama-Alt and write all reports.

    Reads the predictions CSV stored next to this script and writes everything
    under ``<script dir>/results``:
      * ROC curves;
      * per-model confusion matrices and error patterns;
      * only when both models were processed successfully: class deltas,
        comparison tables and ``analysis_summary.txt``.

    A failure while processing one model is logged and that model is skipped.
    Any other error is logged with its traceback and re-raised.
    """
    try:
        start_time = time.time()
        print_progress("Starting analysis...")

        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(script_dir, "results")
        os.makedirs(output_dir, exist_ok=True)

        print_progress("Reading data file...")
        results_file = os.path.join(script_dir, "final_results_hf.co_shreyasmeher_ConflLlama-Alt_Q8_0_cutoff_20170101.csv")
        df = pd.read_csv(results_file, low_memory=False)

        # Prediction columns of the two models being compared
        model_cols = [
            'hf.co/shreyasmeher/ConflLlama:Q8_0_predictions',
            'hf.co/shreyasmeher/ConflLlama-Alt:Q8_0_predictions'
        ]

        # Column name without '_predictions' -> display name used in plots/tables
        model_name_mapping = {
            'hf.co/shreyasmeher/ConflLlama:Q8_0': 'ConflLlama-Standard',
            'hf.co/shreyasmeher/ConflLlama-Alt:Q8_0': 'ConflLlama-Alt'
        }

        gold_cols = ['attacktype1_txt', 'attacktype2_txt', 'attacktype3_txt']

        print_progress("Extracting unique labels...")
        unique_labels = sorted({
            label for col in gold_cols for label in df[col].dropna().unique()
        })

        # The gold standard is identical for every model, so extract it once
        # instead of re-running the row-wise apply inside the model loop.
        print_progress("Extracting gold standard labels...")
        true_primary = df['attacktype1_txt']
        true_multi = df.apply(get_gold_standard_labels, axis=1)

        single_label_results = {}
        multi_label_results = {}
        colors = get_model_colors()

        print_progress("\nCalculating ROC curves...")
        roc_data, class_roc_data = calculate_roc_curves(df, model_cols, 'attacktype1_txt', unique_labels)

        if roc_data:
            print_progress("Plotting ROC curves...")
            plot_roc_curves(roc_data, model_name_mapping, colors, output_dir)
            plot_class_roc_curves(class_roc_data, model_name_mapping, colors, unique_labels, output_dir)

        for model_col in model_cols:
            # Resolved before the try so the error handler can always name the model
            base_model_name = model_col.replace('_predictions', '')
            display_name = model_name_mapping.get(base_model_name, base_model_name)
            try:
                print_progress(f"\nProcessing {display_name}...")

                parsed = [parse_prediction(pred) for pred in df[model_col]]
                primary_preds = [primary for primary, _ in parsed]
                all_preds = [labels for _, labels in parsed]

                print_progress("Calculating metrics...", indent=1)
                single_metrics = calculate_single_label_metrics(
                    true_primary, primary_preds, unique_labels
                )
                multi_metrics = calculate_multi_label_metrics(
                    true_multi, all_preds, unique_labels
                )

                # Store both only once both succeeded, so the two dicts stay in sync
                single_label_results[display_name] = single_metrics
                multi_label_results[display_name] = multi_metrics

                model_output_dir = os.path.join(output_dir, f'analysis_{sanitize_filename(display_name)}')
                os.makedirs(model_output_dir, exist_ok=True)

                plot_confusion_matrix(
                    true_primary,
                    primary_preds,
                    unique_labels,
                    f'Confusion Matrix - {display_name}',
                    os.path.join(model_output_dir, 'confusion_matrix.png')
                )

                print_progress("Analyzing error patterns...", indent=1)
                analyze_error_patterns(
                    true_primary, primary_preds, unique_labels,
                    model_output_dir
                )

            except Exception as e:
                print_progress(f"Error processing {display_name}: {str(e)}", indent=1)
                print_progress(f"Traceback: {traceback.format_exc()}", indent=2)
                continue

        # The comparison steps need results for every model
        compared_models = list(model_name_mapping.values())
        all_models_ok = all(
            name in single_label_results and name in multi_label_results
            for name in compared_models
        )

        if all_models_ok:
            print_progress("\nCreating summary visualizations...")

            _, delta_stats = calculate_class_deltas(single_label_results, unique_labels, output_dir)
            improvement = delta_stats['largest_improvement']
            decline = delta_stats['largest_decline']
            print_progress(f"Largest performance improvement: {improvement['class']} ({improvement['value']:.3f})")
            print_progress(f"Largest performance decline: {decline['class']} ({decline['value']:.3f})")
            print_progress(f"Average F1-score delta: {delta_stats['average_changes']['F1-score']:.3f}")

            print_progress("\nAnalyzing Unknown predictions...")
            for model_name, metrics in multi_label_results.items():
                print_progress(f"\n{model_name} Unknown statistics:", indent=1)
                print_progress(f"Unknown prediction rate: {metrics['Unknown_Rate']:.3f}", indent=2)
                print_progress(f"Unknown prediction accuracy: {metrics['Unknown_Accuracy']:.3f}", indent=2)
                print_progress(f"True Unknown rate: {metrics['Unknown_True_Rate']:.3f}", indent=2)

            print_progress("\nCreating comparative tables...")
            create_comparative_tables(single_label_results, multi_label_results, output_dir)

            print_progress("\nGenerating final summary report...")
            with open(os.path.join(output_dir, 'analysis_summary.txt'), 'w', encoding='utf-8') as f:
                f.write("Model Comparison Analysis Summary\n")
                f.write("================================\n\n")

                f.write("1. Overall Performance\n")
                f.write("-----------------\n")
                for model, metrics in multi_label_results.items():
                    f.write(f"\n{model}:\n")
                    f.write(f"  - Macro F1-score: {single_label_results[model]['macro_avg']['F1-score']:.3f}\n")
                    f.write(f"  - Unknown handling accuracy: {metrics['Unknown_Accuracy']:.3f}\n")
                    f.write(f"  - Multi-label accuracy: {metrics['Subset_Accuracy']:.3f}\n")

                # calculate_class_deltas returns absolute differences (Alt - Standard)
                # on the metrics' 0-1 scale; earlier reports mislabelled them as %.
                f.write("\n2. Key Improvements (absolute F1-score delta, Alt - Standard)\n")
                f.write("----------------\n")
                f.write(f"Largest improvement: {improvement['class']} ")
                f.write(f"({improvement['value']:+.3f})\n")
                f.write(f"Average change: {delta_stats['average_changes']['F1-score']:+.3f}\n")

                f.write("\n3. Areas for Attention\n")
                f.write("-------------------\n")
                f.write(f"Largest decline: {decline['class']} ")
                f.write(f"({decline['value']:+.3f})\n")

                f.write("\n4. Unknown Prediction Analysis\n")
                f.write("--------------------------\n")
                f.write(f"Unknown F1-score delta: {delta_stats['unknown_delta']['f1_change']:+.3f}\n")
        elif single_label_results:
            print_progress("\nSkipping model comparison: not every model was processed successfully.")

        execution_time = time.time() - start_time
        print_progress(f"\nAnalysis complete! Total execution time: {execution_time:.2f} seconds")

    except Exception as e:
        print_progress(f"Fatal error in main execution: {str(e)}")
        print_progress(f"Traceback: {traceback.format_exc()}")
        raise

# Script entry point: run the full analysis only when executed directly.
if __name__ == '__main__':
    main()